-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow passing non-default modules to pipeline #188
Conversation
Override modules are recognized and replaced in the pipeline. However, no check is performed about mismatched classes yet. This is because the override module is already instantiated and we have no library or class name to compare against.
The documentation is not available anymore as the PR was closed or merged. |
I tried the following (as a quick hack before refactoring): for name, module in passed_class_obj.items():
# TODO: verify that the module class belongs to one of the supported classes
library_name, class_name = config_dict[name]
library = importlib.import_module(library_name)
loadable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in loadable_classes.keys()}
for class_name, class_candidate in class_candidates.items():
if isinstance(module, class_candidate):
init_kwargs[name] = module
# Remove it even if not found, as it's not appropriate
init_dict.pop(name) However, if we pass a scheduler instance to If we want to really verify this, I think we should create a more fine-grained mapping from module keys to supported classes, instead of checking all the loadable/importable classes in the library. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, although I would add at least one test too 😅
Just added a test, all other tests pass |
@anton-l - feel free to merge and then maybe also add it manually to the release notes quickly :-) |
Merging! |
* Allow passing non-default modules to pipeline. Override modules are recognized and replaced in the pipeline. However, no check is performed about mismatched classes yet. This is because the override module is already instantiated and we have no library or class name to compare against. * up * add test Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Addresses #183.
Override modules are recognized and replaced in the pipeline. However, no check is performed about mismatched classes yet. This is because the override module is already instantiated (see https://github.com/huggingface/diffusers/blob/main/src/diffusers/configuration_utils.py#L223), and
init_dict
in https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipeline_utils.py#L151 no longer has alibrary_name
or aclass_name
- just the instantiated module.I'm looking at a way to detect a class mismatch so we can fail more gracefully.